import logging
import wrapt
import torch
import megatron
from megatron.core import mpu
import megatron.core.parallel_state as parallel_state
from megatron.training.training import setup_model_and_optimizer
from epx.optim import epx_wrap_optimizer_instance

from megatron.core.optimizer import ChainedOptimizer
from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer

logger = logging.getLogger(__name__)

# Do not use core.transformer.module.Float16Module now
# from megatron.core.transformer.module import Float16Module
# setattr(sys.modules["megatron.training.training"], "Float16Module", Float16Module)

@wrapt.decorator
def setup_model_and_optimizer_wrapper(wrapped, _, args, kwargs):
    def _dump_state_dict():
        nonlocal optimizer

        state_dict = {}

        if len(model) == 1:
            state_dict['model'] =  model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] =  model[i].state_dict_for_save_checkpoint()

        if opt_param_scheduler is not None:
            state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()

        # Optimizer stuff.
        if optimizer is not None and not optimizer.is_stub_optimizer:
            state_dict["optimizer"] = []
            if isinstance(optimizer, ChainedOptimizer):
                for optim in optimizer.chained_optimizers:
                    inner_state_dict = optim.optimizer.state_dict()
                    shard_fp32_from_float16_groups = optim.shard_fp32_from_float16_groups
                    state_dict["optimizer"].append({ "inner_state_dict" : inner_state_dict,
                                                        "shard_fp32_from_float16_groups": shard_fp32_from_float16_groups})
            elif isinstance(optimizer, DistributedOptimizer):
                inner_state_dict = optimizer.optimizer.state_dict()
                shard_fp32_from_float16_groups = optimizer.shard_fp32_from_float16_groups
                state_dict["optimizer"] = { "inner_state_dict" : inner_state_dict,
                                            "shard_fp32_from_float16_groups": shard_fp32_from_float16_groups}
            else:
                assert False, f"epx _dump_state_dict not support {optimizer} now."

        return state_dict


    def _load_state_dict(state_dict):
        nonlocal optimizer
        opt_param_scheduler.load_state_dict(state_dict["opt_param_scheduler"])

        if len(model) == 1:
            model[0].load_state_dict(state_dict["model"])
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model[i] = state_dict['model%d' % i]

        if optimizer is not None and not optimizer.is_stub_optimizer:
            if isinstance(optimizer, ChainedOptimizer):
                optimizer_states = state_dict["optimizer"]
                assert len(optimizer_states) == len(optimizer.chained_optimizers), "optimizer state size mismatch"
                for optim, state in zip(optimizer.chained_optimizers, optimizer_states):
                    optim.optimizer.load_state_dict(state["inner_state_dict"])
                    _copy_shard_params(state["shard_fp32_from_float16_groups"], optim.shard_fp32_from_float16_groups)
            elif isinstance(optimizer, DistributedOptimizer):
                optimizer.optimizer.load_state_dict(state_dict["optimizer"]["inner_state_dict"])
                _copy_shard_params(state_dict["optimizer"]["shard_fp32_from_float16_groups"], optimizer.shard_fp32_from_float16_groups)
            else:
                assert False, f"epx _load_state_dict not support {optimizer} now."


    def _copy_shard_params(src_params, dst_params):
        assert len(src_params) == len(dst_params), "param size mismatch"
        for src, dst in zip(src_params, dst_params):
            if src is None or dst is None:
                continue

            if isinstance(src, list) and isinstance(dst, list):
                _copy_shard_params(src, dst)
                continue

            assert isinstance(src, torch.Tensor) and isinstance(dst, torch.Tensor), "param type mismatch"
            assert src.shape == dst.shape, "param shape mismatch"
            assert src.dtype == dst.dtype, "param dtype mismatch"
            dst.data.copy_(src.data)

    logger.info("epx wrapped setup_model_and_optimizer")

    model, optimizer, opt_param_scheduler = wrapped(*args, **kwargs)
    lcp = parallel_state.get_epx_data_parallel_lcp()

    logger.info(f"epx register replica_state")
    lcp.register_module("replica_state", _dump_state_dict, _load_state_dict)

    logger.info(f"Start wrap optimizer by epx")

    optimizer = epx_wrap_optimizer_instance(optimizer, lcp)

    logger.info(f"Finished wrap optimizer by epx")

    return model, optimizer, opt_param_scheduler

wraped_setup_model_and_optimizer = setup_model_and_optimizer_wrapper(setup_model_and_optimizer)

megatron.training.training.setup_model_and_optimizer = wraped_setup_model_and_optimizer
